""" Utility Functions """
from functools import partial
from typing import Any, Tuple, List, Dict, Union, Type, Optional, Callable

import jax
import jax.numpy as jnp
from jax import nn
import haiku as hk


# ============================== Activation Functions ============================= #

@jax.jit
def act_mish(x: jax.Array) -> jax.Array:
    """Apply the Mish activation, ``x * tanh(softplus(x))``, element-wise.

    Args:
        x: Input array.

    Returns:
        Array with the same shape as ``x``.
    """
    gate = jnp.tanh(jax.nn.softplus(x))
    return x * gate

@jax.jit
def act_gelu_new(x: jax.Array) -> jax.Array:
    """Apply the tanh approximation of GELU element-wise.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.

    Args:
        x: Input array.

    Returns:
        Array with the same shape as ``x``.
    """
    # Use jnp.pi: this module imports jax.numpy as jnp and never imports
    # numpy as np, so the previous `np.pi` raised NameError on first call.
    inner = jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * (x ** 3))
    return 0.5 * x * (1.0 + jnp.tanh(inner))

# Registry mapping an activation name (string key) to its callable.
# NOTE(review): jax.nn.gelu appears to default to the tanh approximation, so
# 'gelu' and 'gelu_new' may be numerically equivalent -- confirm whether an
# exact GELU was intended for the 'gelu' entry.
act2fn = dict(
    mish=act_mish,
    gelu=nn.gelu,
    gelu_new=act_gelu_new,
    relu=nn.relu,
    tanh=jnp.tanh,
)

# ============================== Weight Initializers ============================= #

def init_lecun_uniform() -> Dict[str, hk.initializers.Initializer]:
    """Build weight/bias initializers: fan-in variance scaling (scale 1.0, uniform).

    Returns:
        Dict with ``"w_init"`` (``VarianceScaling`` with ``scale=1.0``,
        ``mode='fan_in'``, ``distribution='uniform'``) and ``"b_init"``
        (constant zero). Presumably meant to be unpacked as keyword arguments
        into a Haiku layer -- confirm against callers.
    """
    weight_init = hk.initializers.VarianceScaling(
        scale=1.0, mode='fan_in', distribution='uniform')
    bias_init = hk.initializers.Constant(constant=0.)
    return dict(w_init=weight_init, b_init=bias_init)

def init_lecun_normal() -> Dict[str, hk.initializers.Initializer]:
    """Build weight/bias initializers: fan-in variance scaling (scale 1.0, truncated normal).

    Returns:
        Dict with ``"w_init"`` (``VarianceScaling`` with ``scale=1.0``,
        ``mode='fan_in'``, ``distribution='truncated_normal'``) and
        ``"b_init"`` (constant zero). Presumably meant to be unpacked as
        keyword arguments into a Haiku layer -- confirm against callers.
    """
    weight_init = hk.initializers.VarianceScaling(
        scale=1.0, mode='fan_in', distribution='truncated_normal')
    bias_init = hk.initializers.Constant(constant=0.)
    return dict(w_init=weight_init, b_init=bias_init)

def init_he_uniform() -> Dict[str, hk.initializers.Initializer]:
    """Build weight/bias initializers: fan-in variance scaling (scale 2.0, uniform).

    Returns:
        Dict with ``"w_init"`` (``VarianceScaling`` with ``scale=2.0``,
        ``mode='fan_in'``, ``distribution='uniform'``) and ``"b_init"``
        (constant zero). Presumably meant to be unpacked as keyword arguments
        into a Haiku layer -- confirm against callers.
    """
    weight_init = hk.initializers.VarianceScaling(
        scale=2.0, mode='fan_in', distribution='uniform')
    bias_init = hk.initializers.Constant(constant=0.)
    return dict(w_init=weight_init, b_init=bias_init)

def init_he_normal() -> Dict[str, hk.initializers.Initializer]:
    """Build weight/bias initializers: fan-in variance scaling (scale 2.0, truncated normal).

    Returns:
        Dict with ``"w_init"`` (``VarianceScaling`` with ``scale=2.0``,
        ``mode='fan_in'``, ``distribution='truncated_normal'``) and
        ``"b_init"`` (constant zero). Presumably meant to be unpacked as
        keyword arguments into a Haiku layer -- confirm against callers.
    """
    weight_init = hk.initializers.VarianceScaling(
        scale=2.0, mode='fan_in', distribution='truncated_normal')
    bias_init = hk.initializers.Constant(constant=0.)
    return dict(w_init=weight_init, b_init=bias_init)

# ============================== Noising Functions ============================= #

@jax.jit
def fn_foward_noise(x: jax.Array, sqrtab: jax.Array, sqrtmab: jax.Array, rng=None):
    """Mix ``x`` with freshly drawn standard-normal noise.

    Computes ``sqrtab * x + sqrtmab * noise`` where ``noise`` has the same
    shape as ``x``.

    Args:
        x: Array to perturb.
        sqrtab: Multiplier applied to ``x``.
        sqrtmab: Multiplier applied to the sampled noise.
        rng: PRNG key forwarded to ``jax.random.normal``. The default of
            ``None`` is not a usable key, so callers need to supply one.

    Returns:
        Tuple ``(noised_x, noise)``.
    """
    eps = jax.random.normal(rng, shape=x.shape)
    noised = sqrtab * x + sqrtmab * eps
    return noised, eps

# ============================== Loss Functions ============================= #

@jax.jit
def kl_div(p_mean: jax.Array, p_std: jax.Array,
           q_mean: Optional[jax.Array] = None, q_std: Optional[jax.Array] = None):
    """KL divergence KL(p || q) between diagonal Gaussians, averaged over the last axis.

    Args:
        p_mean: Mean of the first distribution ``p``.
        p_std: Standard deviation of ``p``.
        q_mean: Mean of the second distribution ``q``. When ``None``, ``q`` is
            taken to be the standard normal N(0, 1) and ``q_std`` is ignored.
        q_std: Standard deviation of ``q``; required when ``q_mean`` is given.

    Returns:
        The per-element KL terms reduced with ``jnp.mean`` over ``axis=-1``.
    """
    if q_mean is None:
        # Closed form of KL(N(p_mean, p_std) || N(0, 1)).
        return 0.5 * jnp.mean(
            -jnp.log(jnp.square(p_std)) - 1. + jnp.square(p_std) + jnp.square(p_mean),
            axis=-1)
    # KL(p || q) = log(sq / sp) + (sp^2 + (mp - mq)^2) / (2 sq^2) - 1/2.
    # The earlier expression had p and q swapped (it computed KL(q || p)), so
    # passing q = N(0, 1) explicitly disagreed with the default branch above.
    log_ratio = jnp.log(q_std) - jnp.log(p_std)
    scaled_sq = (jnp.square(p_std) + jnp.square(p_mean - q_mean)) / (2 * jnp.square(q_std))
    return jnp.mean(log_ratio + scaled_sq - 0.5, axis=-1)

# ============================== Other Functions ============================= #

@jax.jit
def cos_sim(x: jax.Array, y: jax.Array):
    """Cosine similarity: ``dot(x, y)`` divided by the product of the norms.

    Args:
        x: First array.
        y: Second array.

    Returns:
        ``jnp.dot(x, y) / (jnp.linalg.norm(x) * jnp.linalg.norm(y))``. No
        epsilon is added, so a zero-norm input yields a non-finite result.
    """
    numerator = jnp.dot(x, y)
    denominator = jnp.linalg.norm(x) * jnp.linalg.norm(y)
    return numerator / denominator

def sample_dist(mean: jax.Array, std: jax.Array, deterministic: bool):
    """Return ``mean`` or a reparameterized sample ``mean + std * eps``.

    Deliberately NOT wrapped in ``jax.jit``: ``hk.next_rng_key()`` is a Haiku
    side effect, and calling it inside a raw JAX transform captures the key
    only at trace time. Cached calls would then reuse the same noise, and
    tracers can leak into Haiku's state. Jit the enclosing transformed
    function instead.

    Args:
        mean: Mean of the distribution.
        std: Standard deviation; its shape determines the noise shape.
        deterministic: Python bool. If true, ``mean`` is returned unchanged
            and no random key is consumed.

    Returns:
        ``mean`` when ``deterministic`` is true, otherwise
        ``mean + std * eps`` with ``eps`` drawn from a standard normal using
        the next Haiku RNG key (so this path must run inside a Haiku
        transform that was given an rng).
    """
    if deterministic:
        return mean
    return mean + std * jax.random.normal(hk.next_rng_key(), std.shape)

@jax.jit
def apply_cond(noise: jax.Array, cond: jax.Array) -> jax.Array:
    """Overwrite the leading features of ``noise`` with ``cond``.

    For a 2-D ``noise`` the first ``cond.shape[-1]`` columns of every row are
    replaced. For a 3-D ``noise`` the same columns are replaced only at
    index 0 of the second axis (constraining the first observation of the
    plan to the current observation).

    Args:
        noise: 2-D or 3-D array to be conditioned.
        cond: Values to write; its last dimension sets how many trailing-axis
            entries are replaced.

    Returns:
        A new array with the conditioned entries set.

    Raises:
        NotImplementedError: If ``noise`` is neither 2-D nor 3-D.
    """
    width = cond.shape[-1]
    rank = noise.ndim
    if rank == 2:
        target = noise.at[:, :width]
    elif rank == 3:
        # Only the first step along axis 1 is constrained.
        target = noise.at[:, 0, :width]
    else:
        raise NotImplementedError
    return target.set(cond)
